# classification viz
library(tidyverse)
library(data.table)
library(dtplyr)
library(cowplot)
library(tidytext)

### we have removed the raw text columns from these output dataframes
### to comply with Twitter's terms of service

# Read the five prediction tables used by the figures below. Each call reads
# one CSV from the replication output folder into a data.table.
# NOTE(review): the paths are absolute under ~/Dropbox, so these lines only
# run on a machine that has this exact folder layout.
manuscript_preds <- data.table::fread("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/fullcv_preds.csv")
mir_preds <- data.table::fread("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/fullcv_MIR_preds.csv")
upsample_preds <- data.table::fread("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/fullcv_upsample_preds.csv")
allsample_preds <- data.table::fread("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/allcv_preds.csv")
neutralsample_preds <- data.table::fread("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/neutralcv_preds.csv")

# Add a `date` column to every table by converting its `created_at` column
# with lubridate::as_date(). The panel functions below derive the week number
# from `date`.
manuscript_preds$date <- lubridate::as_date(manuscript_preds$created_at)
mir_preds$date <- lubridate::as_date(mir_preds$created_at)
upsample_preds$date <- lubridate::as_date(upsample_preds$created_at)
allsample_preds$date <- lubridate::as_date(allsample_preds$created_at)
neutralsample_preds$date <- lubridate::as_date(neutralsample_preds$created_at)

# Anchor date for the weekly x-axis of make_panel_a() and make_panel_c():
# week k is drawn at first_monday + (k - 1) weeks.
# NOTE: despite its name this is a Sunday (2020-01-05). The week start is
# pinned to 7 (Sunday, lubridate's default) so the value does not change with
# the user option `lubridate.week.start`; the hard-coded axis breaks in the
# panel functions (Jan 26, Feb 09, ...) are Sundays as well.
first_monday <- lubridate::ceiling_date(lubridate::ymd("2020-1-01"),
                                        unit = "week", week_start = 7)

# Weekly classification-accuracy panel.
#
# For every (week, fold) pair this computes the share of correct predictions
# and the Democratic share of the rows, draws one point per pair (coloured
# red-to-blue by Democratic share) for weeks after week 3, and overlays the
# per-week median accuracy for weeks 4-13 as black crosses joined by a line.
# A dashed horizontal line marks accuracy 0.5.
#
# Args:
#   data:  table with columns `date`, `fold`, `caucus` ("D"/"R") and `correct`.
#   title: text shown in the facet strip.
#   cap:   plot caption.
#   fm:    Date used as the x position of week 1; week k is placed at
#          fm + (k - 1) weeks.
#
# Returns: a ggplot object.
make_panel_a <- function(data, title, cap, fm = first_monday){

  # Accuracy and partisan composition per week and cross-validation fold
  weeklypred <- lazy_dt(data) %>%
    mutate(week = lubridate::week(date)) %>%
    mutate(weekdate = fm + lubridate::dweeks(week - 1)) %>%
    group_by(week, weekdate, fold) %>%
    summarise(n = n(),
              n_test_dem = sum(caucus == "D"),
              n_test_rep = sum(caucus == "R"),
              acc = mean(correct)) %>%
    mutate(pct_dem = n_test_dem / (n_test_dem + n_test_rep)) %>%
    as_tibble()

  # Median accuracy across folds for weeks 4-13. Grouping on both `week` and
  # `weekdate` keeps the week number taken from the data itself, so a week
  # that is absent from `data` simply yields no row (the previous hard-coded
  # `week = 4:13` assignment failed unless all ten weeks were present).
  median_acc <- weeklypred %>%
    ungroup() %>%
    filter(week %in% 4:13) %>%
    group_by(week, weekdate) %>%
    summarise(acc = median(acc)) %>%
    ungroup() %>%
    mutate(variable = title)

  panel <- weeklypred %>%
    filter(week > 3) %>%
    ggplot(aes(x = weekdate, y = acc,
               group = week,
               col = pct_dem)) +
    ggbeeswarm::geom_quasirandom(size = 3, alpha = I(2/3)) +
    # `variable` only exists in median_acc; it supplies the strip label
    facet_wrap(~variable) +
    geom_point(data = median_acc, aes(x = weekdate, y = acc, group = week),
               col = "black", size = 5, shape = 4, stroke = 2) +
    geom_line(data = median_acc, aes(x = weekdate, y = acc, group = variable),
              col = "black", size = 1.15) +
    geom_hline(yintercept = .5, col = "black", size = 1.5, linetype = "dashed") +
    scale_color_gradient(name = "Democratic Share",
                         low = "red", high = "blue") +
    labs(x = "", y = "Fit",
         caption = cap) +
    #scale_y_continuous(limits = c(0.5, 1)) +
    theme_bw() +
    scale_x_date(breaks = lubridate::as_date(c("2020-01-26", "2020-02-09",
                                               "2020-02-23","2020-03-08",
                                               "2020-03-22")),
                 labels = c("Jan 26", "Feb 09",
                            "Feb 23", "Mar 08",
                            "Mar 22")) +
    theme(axis.text = element_text(size = 12),
          strip.text= element_text(size = 12),
          #legend.title = element_text(size = 8),
          legend.position = "none",
          plot.caption = element_text(size = 12, face = "bold", hjust = 0))

  panel
}

# Ideology-versus-prediction scatter panel.
#
# Collapses `data` to one row per (id, nominate.dim1, caucus) holding the
# median of the column passed as `.y`, then plots that median against
# `nominate.dim1`, coloured by caucus. A grey band spans from the lowest
# Republican median (minus 0.01) to the highest Democratic median (plus 0.01).
#
# Args:
#   data:  table with columns `id`, `nominate.dim1`, `caucus` ("D"/"R") and
#          the column named by `.y`.
#   title: text shown in the facet strip.
#   .y:    unquoted name of the column to summarise.
#   ylab:  y-axis label.
#   cap:   plot caption.
#
# Returns: a ggplot object.
make_panel_b <- function(data, title, .y, ylab, cap){

  .y <- enquo(.y)

  pp <-
    lazy_dt(data) %>%
    group_by(id, nominate.dim1, caucus) %>%
    summarise(mpred = median(!!.y)) %>%
    as_tibble()

  # Limits of the overlap band, computed once instead of once per row
  dem_max <- max(pp$mpred[pp$caucus == "D"])
  rep_min <- min(pp$mpred[pp$caucus == "R"])

  # 1 when a Republican sits below the highest Democrat or a Democrat sits
  # above the lowest Republican, 0 otherwise. Vectorised, so it also returns a
  # plain numeric vector when `pp` has no rows.
  pp$in_band <- as.numeric(
    (pp$caucus == "R" & pp$mpred < dem_max) |
      (pp$caucus == "D" & pp$mpred > rep_min)
  )

  pp$title <- title

  panel <- pp %>%
    ggplot(aes(x = nominate.dim1, y = mpred, col = caucus, group = caucus))+
    geom_point(size = 3.5, alpha = I(1/3)) +
    facet_wrap(~title) +
    annotate("rect",
             xmin = -Inf,
             xmax = Inf,
             ymin = rep_min - 0.01,
             ymax = dem_max + 0.01,
             fill = "gray10", alpha = I(1/4), color = NA) +
    scale_color_manual(name = "Caucus",
                       breaks = c("D","R"),
                       values = c("darkblue","darkred"),
                       labels = c("Democratic","Republican")) +
    theme_bw() +
    theme(legend.title = element_blank(), legend.position = c(0.15, 0.8),
          plot.caption = element_text(size = 12, face = "bold", hjust = 0),
          axis.text = element_text(size = 12),
          axis.title = element_text(size = 12),
          legend.text = element_text(size=11),
          strip.text = element_text(size = 12)) +
    labs(x = "Ideology (DW-NOMINATE first dimension)",
         y = ylab,
         caption = cap)

  panel
}

# Recall-by-party panel.
#
# For each week and caucus this computes `acc` (mean of `correct`) and
# `share_of_test_set` (that caucus' share of the week's rows). Both are drawn
# as lines for weeks 4-13 (solid = `acc`, dashed = `share_of_test_set`), with
# a ribbon between them, in one facet per party.
#
# Args:
#   data:       table with columns `date`, `caucus` ("D"/"R") and `correct`.
#   plot_title: plot title.
#   cap:        plot caption.
#   fm:         Date used as the x position of week 1; week k is placed at
#               fm + (k - 1) weeks.
#
# Returns: a ggplot object.
make_panel_c <- function(data, plot_title,
                         cap = "Dashed lines indicate no-information rate of recall by week\nSolid lines indicate observed rate of recall by week",
                         fm = first_monday){
  weeklypred_party <- lazy_dt(data) %>%
    mutate(week = lubridate::week(date)) %>%
    mutate(weekdate = fm + lubridate::dweeks(week - 1)) %>%
    group_by(week, weekdate, caucus) %>%
    # NOTE(review): 15 is hard-coded, presumably the number of
    # cross-validation folds -- confirm. It cancels out in
    # share_of_test_set below.
    summarise(avg_n = n() / 15,
              acc = mean(correct)) %>%
    group_by(week) %>%
    mutate(share_of_test_set = avg_n / sum(avg_n)) %>%
    as_tibble()

  panel <- weeklypred_party %>%
    mutate(party = factor(ifelse(caucus == "D", "Democrats", "Republicans"),
                          levels = c("Republicans","Democrats"))) %>%
    filter(week %in% c(4:13)) %>%
    ggplot(aes(x = weekdate, col = party, fill = party))+
    facet_wrap(~party, nrow = 1, ncol = 2) +
    # NOTE(review): `alpha` is mapped inside aes(), so the ribbon opacity is
    # chosen by the default alpha scale rather than being literally 0.3. It is
    # left as is so the rendered figure does not change.
    geom_ribbon(aes(ymin = share_of_test_set,
                    ymax = acc,
                    alpha = .3)) +
    geom_line(aes(y = acc, lty = "1observed"), size = 1.05) +
    geom_line(aes(y = share_of_test_set, lty = "2noinfo"), size = 1.05) +
    # Named values tie each linetype to its key explicitly, so the mapping
    # does not depend on how a given ggplot2 version pairs unnamed values
    # with breaks.
    scale_linetype_manual(name = "",
                          breaks = c("2noinfo","1observed"),
                          values = c("2noinfo" = "dashed",
                                     "1observed" = "solid"),
                          labels = c( "No-Information Recall","Observed Recall")) +
    guides(alpha = "none", linetype = "none", fill = "none")+
    labs(x = "", y = "Recall",
         title = plot_title,
         caption = cap)+
    theme_bw()+
    scale_x_date(breaks = lubridate::as_date(c("2020-01-26", "2020-02-09",
                                               "2020-02-23","2020-03-08",
                                               "2020-03-22")),
                 labels = c("Jan 26", "Feb 09",
                            "Feb 23", "Mar 08",
                            "Mar 22"))+
    scale_fill_manual(values = c("darkred", "darkblue")) +
    scale_color_manual(values = c("darkred", "darkblue")) +
    theme(plot.title = element_text(size = 24),
          legend.position = "none",
          plot.caption = element_text(size = 12, face = "bold", hjust = 0),
          axis.text = element_text(size = 12),
          axis.title = element_text(size = 12),
          legend.text = element_text(size=11),
          strip.text = element_text(size = 12))

  panel
}

# Horizontal bar chart of random-forest variable importance.
#
# Args:
#   rf_mod:         fitted model object exposing `variable.importance` (a
#                   named numeric vector) plus `r.squared` or
#                   `prediction.error` (presumably a ranger fit -- confirm).
#   out:            plot title.
#   xvars:          label for the variable axis.
#   sub:            plot subtitle.
#   classification: TRUE to report the OOB prediction error in the caption,
#                   FALSE (default) to report the OOB R-squared.
#   shrink:         optional threshold; only variables with importance
#                   strictly greater than it are kept.
#
# The variable "day_relative_1120" is always dropped before plotting.
#
# Returns: a ggplot object.
rf_to_viz <- function(rf_mod, out, xvars, sub,
                      classification = FALSE,
                      shrink = NULL){

  # Fail early with a clear message instead of an "object not found" error
  # further down when the flag is neither TRUE nor FALSE.
  classification <- as.logical(classification)
  if (length(classification) != 1 || is.na(classification)) {
    stop("`classification` must be a single TRUE or FALSE", call. = FALSE)
  }

  # get variable importance into data frame (row names carry the variable names)
  vimp_lpc <- data.frame(rf_mod$variable.importance)
  names(vimp_lpc) <- "importance"
  vimp_lpc$var <- rownames(vimp_lpc)

  vimp_lpc <- vimp_lpc %>% filter(var != "day_relative_1120")

  if (!is.null(shrink)) {
    vimp_lpc <- vimp_lpc %>% filter(importance > shrink)
  }

  # The caption is the only thing that differs between the two model types
  fit_caption <- if (classification) {
    paste0("OOB Prediction Error: ", round(rf_mod$prediction.error, 3))
  } else {
    paste0("OOB R-Squared: ", round(rf_mod$r.squared, 3))
  }

  vimp_lpc %>%
    ggplot(aes(x = reorder(var, importance),
               y = importance)) +
    geom_bar(stat = "identity",
             position = "dodge") +
    coord_flip() +
    ylab("Variable importance (permutation)") +
    xlab(xvars) +
    labs(title = paste0(out),
         subtitle = sub,
         caption = fit_caption) +
    guides(fill = "none") +
    theme_bw()
}


## Load in the data for figure 1
## The two .Rda files are expected to define `cumulative_data` and
## `all_words_compare`, which are used below.
load("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_1a_data.Rda")
load("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_1b_data.Rda")

## Figure 1a: cumulative step curves by `party`, with two dashed date markers
## and a vertical text label next to each, placed at half the maximum count
cumulative_tweets <- ggplot(cumulative_data, aes(x = date, y = cumulative_sum, col = party)) + 
  geom_step(size = 1.5) + 
  theme_bw() + 
  scale_color_manual(values = c(# "darkorchid4",
    "darkblue", "darkred","#2e8b57")) + 
  theme(legend.position = c(0.125, 0.85), legend.title = element_blank(),
        plot.caption = element_text(size = 12, face = "bold", hjust = 0),
        axis.text = element_text(size = 12),
        axis.title = element_text(size = 12),
        legend.text = element_text(size=11)) +
  labs(y = "Cumulative count", x = "", caption = "(a) Cumulative tweets, infections, and deaths on COVID-19") + 
  geom_vline(xintercept = c(as.Date("2020-02-28"), as.Date("2020-03-13")), size = 1, linetype = "dashed")  + 
  annotate(geom = "text", x = as.Date("2020-02-24"), y = max(cumulative_data$cumulative_sum, na.rm = TRUE)/2, label = "CDC identifies first case of community spread",
           angle = 90, vjust = 1.25) + 
  annotate(geom = "text", x = as.Date("2020-03-09"), y = max(cumulative_data$cumulative_sum, na.rm = TRUE)/2, label = "National emergency declared",
           angle = 90, vjust = 1.25) +
  scale_y_continuous(labels = scales::comma)

figure_1b_data <- all_words_compare

## Figure 1b: horizontal bars of `pol_diff` per word, ordered by `pol_diff`
barplot_abs_diff <- ggplot(figure_1b_data,  aes(x = reorder(word, pol_diff), y = pol_diff, fill = party)) + 
  geom_col() + 
  theme_bw() + 
  coord_flip() + 
  scale_fill_manual(values = c("darkred", "darkblue")) + 
  theme(legend.position = "none", legend.title = element_blank(),
        plot.caption = element_text(size = 12, face = "bold", hjust = 0),
        axis.text = element_text(size = 12),
        axis.title = element_text(size = 12)) + 
  labs(x = "", y = "", caption = "(b) Absolute difference in words used by party") + 
  scale_y_continuous(limits = c(-.13, .13), 
                     breaks = c(-0.1, -0.05, 0, .05, 0.1), 
                     labels = c("10% more\n Democratic",
                                "5% more\n Democratic",
                                "Same",
                                "5% more\n Republican",
                                "10% more\n Republican"))


## Combine the two panels side by side and write PNG and PDF versions
manuscript_figure_1 <- gridExtra::grid.arrange(cumulative_tweets, barplot_abs_diff, ncol = 2)
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_1.png",
       plot = manuscript_figure_1, width = 12.5, height = 6.25, units = "in")
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_1.pdf",
       plot = manuscript_figure_1, width = 12.5, height = 6.25, units = "in")

## Figure 2: the three panels built from the manuscript predictions,
## (a) and (b) side by side on the top row, (c) across the bottom row
manuscript_2a <- make_panel_a(
  data = manuscript_preds,
  title = "Classification Accuracy by Week",
  cap = "(a) Points represent accuracy rate in the given week\n      for each fold of cross-validation"
)
manuscript_2b <- make_panel_b(
  data = manuscript_preds,
  title = "Partisan communication and roll call voting",
  .y = preds,
  ylab = "Median p(Republican) of COVID-19 Tweets",
  cap = "(b) Political ideology and predicted probability"
)
manuscript_2c <- make_panel_c(
  data = manuscript_preds,
  plot_title = "",
  cap = "(c) Recall above no-information rate for Republicans and Democrats"
)

fig2 <- plot_grid(
  plot_grid(manuscript_2a, manuscript_2b, ncol = 2, align = "h", axis = "bt"),
  manuscript_2c,
  nrow = 2
)

## Write PNG and PDF versions
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_2.png",
       plot = fig2, width = 10, height = 11, units = "in")
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_2.pdf",
       plot = fig2, width = 10, height = 11, units = "in")

## Figure S1: weekly accuracy panels for the multinomial inverse regression
## (left) and random forest (right) predictions
s1_rf <- make_panel_a(data = manuscript_preds,
                      title = "Random forest classification accuracy by week",
                      cap = "")
s1_mir <- make_panel_a(data = mir_preds,
                      title = "Multinomial inverse regression classification accuracy by week",
                      cap = "")

fig_s1 <- gridExtra::grid.arrange(s1_mir, s1_rf, nrow = 1, ncol = 2)
# `filename` and `plot` are spelled out; the earlier `file =` only worked
# through partial matching of the argument name.
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s1.png",
       plot = fig_s1, width = 12, height = 8, units = "in")
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s1.pdf",
       plot = fig_s1, width = 12, height = 8, units = "in")

## Figure S2: recall panels for the random forest (top) and multinomial
## inverse regression (bottom) predictions
s2_rf <- make_panel_c(data = manuscript_preds,
                      plot_title = "Random forest",
                      cap = "")
s2_mir <- make_panel_c(data = mir_preds,
                       plot_title = "Multinomial inverse regression",
                       cap = "Random forest and MIR Recall above no-information rate for Republicans and Democrats")

fig_s2 <- gridExtra::grid.arrange(s2_rf, s2_mir, nrow = 2, ncol = 1)
# `filename` and `plot` are spelled out; the earlier `file =` only worked
# through partial matching of the argument name.
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s2.png",
       plot = fig_s2, width = 12, height = 8, units = "in")
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s2.pdf",
       plot = fig_s2, width = 12, height = 8, units = "in")


## Figure S3: ideology scatter panels for the random forest (left, median of
## `preds`) and multinomial inverse regression (right, median of `Y_train`)
s3_rf <- make_panel_b(data = manuscript_preds, 
                      title ="Random Forest",
                      .y = preds,
                      ylab = "Median p(Republican)",
                      cap = "")
s3_mir <- make_panel_b(data = mir_preds, 
                      title ="Multinomial Inverse Regression",
                      .y = Y_train,
                      ylab = "Median coefficient projection",
                      cap = "")

fig_s3 <- gridExtra::grid.arrange(s3_rf, s3_mir, nrow = 1, ncol = 2)
# `filename` and `plot` are spelled out; the earlier `file =` only worked
# through partial matching of the argument name.
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s3.png",
       plot = fig_s3, width = 12, height = 8, units = "in")
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s3.pdf",
       plot = fig_s3, width = 12, height = 8, units = "in")


## Figure S4: the three-panel layout built from neutralsample_preds

# (a) weekly accuracy; the facet strip is removed because the title is empty
s4_acc <- make_panel_a(
  data = neutralsample_preds,
  title = "",
  cap = "(a) Points represent accuracy rate in the given week\n     for each fold of cross-validation"
) +
  theme(strip.text.x = element_blank(),
        strip.background = element_blank())

# NOTE(review): this load() is expected to replace the global
# `all_words_compare`; if the file does not contain that object, whatever
# value was loaded earlier in the script is silently reused -- confirm.
load("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/NeutralDiffInUse.rda")

# (b) horizontal bars of `pol_diff` per word, ordered by `pol_diff`
s4_absdiff_dat <- all_words_compare
s4_absdiff <- ggplot(
  s4_absdiff_dat,
  aes(x = reorder(word, pol_diff), y = pol_diff, fill = party)
) +
  geom_col() +
  coord_flip() +
  scale_fill_manual(values = c("darkred", "darkblue")) +
  scale_y_continuous(
    limits = c(-.12, .12),
    breaks = c(-0.1, -0.05, 0, .05, 0.1),
    labels = c("10% more\n Democratic",
               "5% more\n Democratic",
               "Same",
               "5% more\n Republican",
               "10% more\n Republican")
  ) +
  labs(x = "", y = "",
       caption = "(b) Absolute difference in words used\n     by party") +
  theme_bw() +
  theme(
    legend.position = "none",
    legend.title = element_blank(),
    plot.caption = element_text(size = 12, face = "bold", hjust = 0),
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 12)
  )

# (c) recall by party
s4_recall <- make_panel_c(
  data = neutralsample_preds,
  plot_title = "",
  cap = "(c) Recall above no-information rate for Republicans and Democrats"
)

# (a) and (b) share the top row, (c) fills the bottom row
fig_s4 <- plot_grid(
  plot_grid(s4_acc, s4_absdiff, ncol = 2, align = "h", axis = "bt"),
  s4_recall,
  nrow = 2
)

ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s4.png",
       plot = fig_s4, width = 16, height = 11, units = "in")
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s4.pdf",
       plot = fig_s4, width = 16, height = 11, units = "in")

## Figure S5: the three-panel layout built from allsample_preds

# (a) weekly accuracy; the facet strip is removed because the title is empty
s5_acc <- make_panel_a(
  data = allsample_preds,
  title = "",
  cap = "(a) Points represent accuracy rate in the given week\n     for each fold of cross-validation"
) +
  theme(strip.text.x = element_blank(),
        strip.background = element_blank())

# (b) horizontal bars of `pol_diff` per word, ordered by `pol_diff`; `party`
# is turned into a factor with "Republican" as the first level for the fill
s5_absdiff_dat <- readr::read_csv("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/random_sample_abs_diff.csv")
s5_absdiff <- ggplot(
  s5_absdiff_dat,
  aes(x = reorder(word, pol_diff),
      y = pol_diff,
      fill = factor(party, levels = c("Republican","Democrat")))
) +
  geom_col() +
  coord_flip() +
  scale_fill_manual(values = c("darkred", "darkblue")) +
  scale_y_continuous(
    limits = c(-.12, .12),
    breaks = c(-0.1, -0.05, 0, .05, 0.1),
    labels = c("10% more\n Democratic",
               "5% more\n Democratic",
               "Same",
               "5% more\n Republican",
               "10% more\n Republican")
  ) +
  labs(x = "", y = "",
       caption = "(b) Absolute difference in words used\n     by party") +
  theme_bw() +
  theme(
    legend.position = "none",
    legend.title = element_blank(),
    plot.caption = element_text(size = 12, face = "bold", hjust = 0),
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 12)
  )

# (c) recall by party
s5_recall <- make_panel_c(
  data = allsample_preds,
  plot_title = "",
  cap = "(c) Recall above no-information rate for Republicans and Democrats"
)

# (a) and (b) share the top row, (c) fills the bottom row
fig_s5 <- plot_grid(
  plot_grid(s5_acc, s5_absdiff, ncol = 2, align = "h", axis = "bt"),
  s5_recall,
  nrow = 2
)

ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s5.png",
       plot = fig_s5, width = 16, height = 11, units = "in")
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s5.pdf",
       plot = fig_s5, width = 16, height = 11, units = "in")

# make figure S6
# The .RData file is expected to define `vimp_lpc`, `tmp` and `rf1`, which
# are used below.
load("~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/covid_token_importance.RData")

# Bars show permutation importance per feature. The sign of `Y` joined from
# `tmp` sets the colour, and features with negative `Y` are mirrored to the
# negative side of the axis; axis labels show absolute values.
figure_s6 <-
  vimp_lpc %>% left_join(tmp[,c("word","Y","freq")], 
                         by = c("var" = "word")) %>%
  filter(var != "day_relative_1120") %>%
  filter(importance > .001) %>%
  mutate(fillcol = case_when(Y > 0 ~ "red",
                             Y < 0 ~ "blue",
                             Y == 0 | is.na(Y) ~ "black"),
         val = ifelse(fillcol == "blue", -1*importance, importance)) %>%
  ggplot(aes(x=reorder(var,val), 
             y=val,
             fill = factor(fillcol, levels = c("blue","red","black"))))+ 
  geom_col(position="dodge")+ 
  scale_fill_manual(name = "Generally indicates author is a...",
                    breaks = c("blue", "red","black"),
                    values = c("darkblue","darkred","black"),
                    labels = c("Democrat","Republican","No Direction"))+
  coord_flip()+
  scale_y_continuous(name = "Random Forest Permutation Importance",
                     breaks = seq(from = -.0075, to = .0075, by = .0025),
                     labels = paste0(abs(seq(from = -.0075, to = .0075, by = .0025))))+
  xlab("Features")+
  labs(title = "Most Important Text Features",
       subtitle = "Partisan association defined by sign of multinomial inverse regression coefficient",
       caption = paste0("OOB Prediction Error: ", round(rf1$prediction.error, 3)))+
  theme_bw()+
  theme(text = element_text(family = "serif"),
        plot.title = element_text(size = 24),
        plot.subtitle = element_text(size =14),
        plot.caption = element_text(size = 12, face = "bold", hjust = 0),
        axis.text = element_text(size = 12),
        axis.title = element_text(size = 12))
# `filename` and `plot` are spelled out; the earlier `file =` only worked
# through partial matching of the argument name.
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s6.png",
       plot = figure_s6, width = 12, height = 6, units = "in")

